import argparse
import os
import sys
import pickle
import math

import torch
import numpy as np
from torchvision import utils

from models.StyleGAN2.model import Generator, Discriminator


def convert_modconv(vars, source_name, target_name, flip=False):
    """Convert one TF StyleGAN2 modulated-conv layer into PyTorch state-dict entries.

    Args:
        vars: mapping from TF variable names to variables; every value is read
            through ``.value().eval()``.
        source_name: TF scope of the layer; ``weight``, ``mod_weight``,
            ``mod_bias``, ``noise_strength`` and ``bias`` are read beneath it.
        target_name: prefix of the PyTorch module inside the state dict.
        flip: when true, the converted kernel is flipped along axes 3 and 4
            with ``torch.flip``.

    Returns:
        dict mapping ``target_name + '.' + suffix`` to tensors, for the suffixes
        ``conv.weight``, ``conv.modulation.weight``, ``conv.modulation.bias``,
        ``noise.weight`` and ``activate.bias`` (in that insertion order).
    """
    def read(suffix):
        return vars[source_name + '/' + suffix].value().eval()

    kernel = read('weight')
    style_weight = read('mod_weight')
    style_bias = read('mod_bias')
    noise_strength = read('noise_strength')
    act_bias = read('bias')

    arrays = [
        # Axes reordered with (3, 2, 0, 1) plus a leading size-1 axis;
        # presumably TF (kh, kw, in, out) -> PyTorch (1, out, in, kh, kw) — confirm.
        ('conv.weight', np.expand_dims(kernel.transpose((3, 2, 0, 1)), 0)),
        ('conv.modulation.weight', style_weight.transpose((1, 0))),
        # Presumably folds the constant +1 the TF graph adds after the
        # modulation bias into the bias itself — confirm against the TF code.
        ('conv.modulation.bias', style_bias + 1),
        # The scalar noise strength becomes a 1-element vector.
        ('noise.weight', np.array([noise_strength])),
        ('activate.bias', act_bias),
    ]

    result = {target_name + '.' + suffix: torch.from_numpy(arr) for suffix, arr in arrays}

    if flip:
        # Flip the last two axes (3 and 4) of the 5-D kernel; presumably the
        # spatial axes, used for the upsampling conv.
        kernel_key = target_name + '.conv.weight'
        result[kernel_key] = torch.flip(result[kernel_key], [3, 4])

    return result


def convert_conv(vars, source_name, target_name, bias=True, start=0):
    """Convert a plain TF conv layer into PyTorch state-dict entries.

    Args:
        vars: mapping from TF variable names to variables, read via
            ``.value().eval()``.
        source_name: TF scope holding ``weight`` (and ``bias`` when ``bias``).
        target_name: prefix of the target PyTorch container.
        bias: whether to also convert ``source_name + '/bias'``.
        start: index of the weight entry inside the target container; the
            bias is written at index ``start + 1``. Presumably the target is
            an indexed container (e.g. ``nn.Sequential``) whose bias lives in
            the module after the conv — confirm against the model.

    Returns:
        dict with ``'{target_name}.{start}.weight'`` (kernel transposed with
        axes (3, 2, 0, 1)) and, if requested, ``'{target_name}.{start+1}.bias'``.
    """
    kernel = vars[source_name + '/weight'].value().eval()
    bias_value = vars[source_name + '/bias'].value().eval() if bias else None

    kernel = kernel.transpose((3, 2, 0, 1))
    weight_key = target_name + '.{}.weight'.format(start)
    converted = {weight_key: torch.from_numpy(kernel)}

    if bias:
        bias_key = target_name + '.{}.bias'.format(start + 1)
        converted[bias_key] = torch.from_numpy(bias_value)

    return converted


def convert_torgb(vars, source_name, target_name):
    """Convert a TF StyleGAN2 ToRGB layer into PyTorch state-dict entries.

    Reads ``weight``, ``mod_weight``, ``mod_bias`` and ``bias`` beneath
    ``source_name`` (each via ``.value().eval()``) and returns a dict keyed
    ``target_name + '.' + suffix``:

    - ``conv.weight``: kernel transposed with axes (3, 2, 0, 1) and given a
      leading size-1 axis.
    - ``conv.modulation.weight``: ``mod_weight`` transposed.
    - ``conv.modulation.bias``: ``mod_bias + 1``.
    - ``bias``: ``bias`` reshaped to (1, 3, 1, 1), so it must hold 3 values.
    """
    def read(suffix):
        return vars[source_name + '/' + suffix].value().eval()

    kernel, style_weight, style_bias, rgb_bias = (
        read(suffix) for suffix in ('weight', 'mod_weight', 'mod_bias', 'bias')
    )

    pairs = (
        ('conv.weight', np.expand_dims(kernel.transpose((3, 2, 0, 1)), 0)),
        ('conv.modulation.weight', style_weight.transpose((1, 0))),
        # Presumably folds the TF graph's constant +1 after the modulation
        # bias into the bias value — confirm against the TF code.
        ('conv.modulation.bias', style_bias + 1),
        ('bias', rgb_bias.reshape((1, 3, 1, 1))),
    )

    return {target_name + '.' + suffix: torch.from_numpy(arr) for suffix, arr in pairs}


def convert_dense(vars, source_name, target_name):
    """Convert a TF dense layer into PyTorch ``weight``/``bias`` entries.

    The TF ``weight`` is transposed (axes swapped) before conversion; the
    ``bias`` is copied as-is. Keys are ``target_name + '.weight'`` and
    ``target_name + '.bias'``, in that order.
    """
    weight = vars[source_name + '/weight'].value().eval()
    bias = vars[source_name + '/bias'].value().eval()

    converted = {}
    for suffix, array in (('weight', weight.transpose((1, 0))), ('bias', bias)):
        converted['{}.{}'.format(target_name, suffix)] = torch.from_numpy(array)
    return converted


def update(state_dict, new):
    """Overwrite entries of ``state_dict`` with the tensors in ``new``, in place.

    Every key in ``new`` must already exist in ``state_dict`` and carry the
    same shape. Entries are written one at a time, so entries processed before
    a failing key stay written.

    Raises:
        KeyError: a key of ``new`` is absent from ``state_dict``.
        ValueError: a tensor's shape differs from the existing entry's shape.
    """
    for name, tensor in new.items():
        if name not in state_dict:
            raise KeyError(name + ' is not found')

        expected_shape = state_dict[name].shape
        if tensor.shape != expected_shape:
            raise ValueError('Shape mismatch: {} vs {}'.format(tensor.shape, expected_shape))

        state_dict[name] = tensor
        
        
def discriminator_fill_statedict(statedict, vars, size):
    """Copy TF StyleGAN2 discriminator weights into a PyTorch state dict.

    Mapping performed:

    - ``{size}x{size}/FromRGB`` -> ``convs.0``
    - for each resolution from ``size`` down to 8 (halving each step), the
      layers ``Conv0``, ``Conv1_down`` and ``Skip`` -> ``convs.{k}.conv1``,
      ``convs.{k}.conv2`` and ``convs.{k}.skip``, with ``k`` counting up from 1.
      ``conv2`` and ``skip`` use container index 1 (``Skip`` has no bias).
    - ``4x4/Conv`` -> ``final_conv``; ``4x4/Dense0`` -> ``final_linear.0``;
      ``Output`` -> ``final_linear.1``.

    ``statedict`` is modified in place (via ``update``) and returned.
    ``size`` is assumed to be a power of two — TODO confirm with callers.
    """
    log_size = int(math.log(size, 2))

    update(statedict, convert_conv(vars, '{}x{}/FromRGB'.format(size, size), 'convs.0'))

    # Resolutions are visited from high to low, matching the discriminator's
    # downsampling order; block indices count up from 1.
    for block_idx, exponent in enumerate(range(log_size - 2, 0, -1), start=1):
        res_tag = '{0}x{0}'.format(4 * 2 ** exponent)
        prefix = 'convs.{}'.format(block_idx)
        update(statedict, convert_conv(vars, res_tag + '/Conv0', prefix + '.conv1'))
        update(statedict, convert_conv(vars, res_tag + '/Conv1_down', prefix + '.conv2', start=1))
        update(statedict, convert_conv(vars, res_tag + '/Skip', prefix + '.skip', start=1, bias=False))

    update(statedict, convert_conv(vars, '4x4/Conv', 'final_conv'))
    update(statedict, convert_dense(vars, '4x4/Dense0', 'final_linear.0'))
    update(statedict, convert_dense(vars, 'Output', 'final_linear.1'))

    return statedict


def fill_statedict(state_dict, vars, size, n_mlp=8):
    """Copy TF StyleGAN2 generator weights into a PyTorch ``Generator`` state dict.

    Args:
        state_dict: the generator's ``state_dict()``; modified in place (via
            ``update``) and returned.
        vars: mapping from TF variable names (e.g. ``'G_mapping/Dense0/weight'``)
            to TF variables, read through ``.value().eval()``.
        size: output resolution; assumed to be a power of two >= 4 — TODO
            confirm callers never pass anything else.
        n_mlp: number of dense layers in the mapping network (``G_mapping``).
            Defaults to 8, the value previously hard-coded here.

    Returns:
        The same ``state_dict`` object.

    Raises:
        KeyError / ValueError: propagated from ``update`` when a target key is
        missing or a converted tensor's shape does not match.
    """
    # math.log2 is more accurate than math.log(x, 2); int() truncates, so an
    # inexact result just below an integer would silently drop a level.
    log_size = int(math.log2(size))
    n_up = log_size - 2  # number of resolution levels above 4x4

    # Mapping network: TF Dense{i} -> style.{i + 1}. Index 0 of ``style`` is
    # skipped — presumably a parameter-free layer; verify against Generator.
    for i in range(n_mlp):
        update(state_dict, convert_dense(vars, 'G_mapping/Dense{}'.format(i), 'style.{}'.format(i + 1)))

    update(
        state_dict,
        {
            'input.input': torch.from_numpy(
                vars['G_synthesis/4x4/Const/const'].value().eval()
            )
        },
    )

    update(state_dict, convert_torgb(vars, 'G_synthesis/4x4/ToRGB', 'to_rgb1'))

    for i in range(n_up):
        reso = 4 * 2 ** (i + 1)
        update(
            state_dict,
            convert_torgb(vars, 'G_synthesis/{}x{}/ToRGB'.format(reso, reso), 'to_rgbs.{}'.format(i)),
        )

    update(state_dict, convert_modconv(vars, 'G_synthesis/4x4/Conv', 'conv1'))

    # Each level above 4x4 contributes two PyTorch convs: the upsampling conv
    # (kernel flipped) followed by a regular conv.
    conv_i = 0
    for i in range(n_up):
        reso = 4 * 2 ** (i + 1)
        update(
            state_dict,
            convert_modconv(
                vars,
                'G_synthesis/{}x{}/Conv0_up'.format(reso, reso),
                'convs.{}'.format(conv_i),
                flip=True,
            ),
        )
        update(
            state_dict,
            convert_modconv(
                vars,
                'G_synthesis/{}x{}/Conv1'.format(reso, reso),
                'convs.{}'.format(conv_i + 1),
            ),
        )
        conv_i += 2

    # One noise input for the 4x4 conv plus two per higher level.
    for i in range(n_up * 2 + 1):
        update(
            state_dict,
            {
                'noises.noise_{}'.format(i): torch.from_numpy(
                    vars['G_synthesis/noise{}'.format(i)].value().eval()
                )
            },
        )

    return state_dict


if __name__ == '__main__':
    import inspect

    # PyTorch device for the sanity-check render; fall back to CPU instead of
    # crashing on machines without CUDA.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    parser = argparse.ArgumentParser()
    parser.add_argument('--repo', type=str, required=True,
                        help='path to the official StyleGAN2 TF repository (provides dnnlib)')
    parser.add_argument('--gen', action='store_true',
                        help='also convert the training (non-EMA) generator')
    parser.add_argument('--disc', action='store_true',
                        help='also convert the discriminator')
    parser.add_argument('path', metavar='PATH', help='TF network pickle to convert')

    args = parser.parse_args()

    # dnnlib is needed both to unpickle the TF networks and to run the TF
    # reference generator below.
    sys.path.append(args.repo)

    import dnnlib
    from dnnlib import tflib

    tflib.init_tf()

    # NOTE(security): pickle.load can execute arbitrary code from the file;
    # only convert network pickles from a trusted source.
    with open(args.path, 'rb') as f:
        generator, discriminator, g_ema = pickle.load(f)

    size = g_ema.output_shape[2]

    # Convert the EMA generator; load_state_dict double-checks the result.
    g = Generator(size, 512, 8)
    state_dict = g.state_dict()
    state_dict = fill_statedict(state_dict, g_ema.vars, size)

    g.load_state_dict(state_dict)

    latent_avg = torch.from_numpy(g_ema.vars['dlatent_avg'].value().eval())

    ckpt = {'g_ema': state_dict, 'latent_avg': latent_avg}

    if args.gen:
        g_train = Generator(size, 512, 8)
        g_train_state = g_train.state_dict()
        g_train_state = fill_statedict(g_train_state, generator.vars, size)
        ckpt['g'] = g_train_state

    if args.disc:
        disc = Discriminator(size)
        d_state = disc.state_dict()
        d_state = discriminator_fill_statedict(d_state, discriminator.vars, size)
        ckpt['d'] = d_state

    # Output files are written to the current directory, named after the pickle.
    name = os.path.splitext(os.path.basename(args.path))[0]
    torch.save(ckpt, name + '.pt')

    # Sanity check: render the same latents with TF and PyTorch and save
    # TF / PyTorch / difference rows in one image.
    batch_size = {256: 16, 512: 9, 1024: 4}
    n_sample = batch_size.get(size, 25)

    g = g.to(device)

    z = np.random.RandomState(0).randn(n_sample, 512).astype('float32')

    with torch.no_grad():
        # randomize_noise=False makes the PyTorch generator use the converted
        # noise buffers, matching the TF run below (which also disables noise
        # randomization); otherwise the diff row reflects random noise rather
        # than conversion errors. truncation=0.5 presumably matches the TF
        # network's default truncation — confirm.
        img_pt, _ = g(
            [torch.from_numpy(z).to(device)],
            truncation=0.5,
            truncation_latent=latent_avg.to(device),
            randomize_noise=False,
        )

    Gs_kwargs = dnnlib.EasyDict()
    Gs_kwargs.randomize_noise = False
    img_tf = g_ema.run(z, None, **Gs_kwargs)
    img_tf = torch.from_numpy(img_tf).to(device)

    img_diff = ((img_pt + 1) / 2).clamp(0.0, 1.0) - ((img_tf + 1) / 2).clamp(0.0, 1.0)

    img_concat = torch.cat((img_tf, img_pt, img_diff), dim=0)

    # torchvision renamed make_grid's ``range`` argument to ``value_range``, and
    # newer releases may reject ``range``; use whichever this install supports.
    if 'value_range' in inspect.signature(utils.make_grid).parameters:
        range_kwarg = {'value_range': (-1, 1)}
    else:
        range_kwarg = {'range': (-1, 1)}
    utils.save_image(img_concat, name + '.png', nrow=n_sample, normalize=True, **range_kwarg)

